import torch
import torch.nn as nn
import logging
from torch.nn import functional as F
import numpy as np
import math

from src.models_classifier import CosineLinear
from timm.models import create_model

logger = logging.getLogger()

# Width of the feature vector produced by each supported backbone; IC_model uses
# it as the input dimension of its classifier.
hidden_dim_dict = dict(
    vit_tiny_patch16_224_in21k=192,
    resnet18=512,
    vit_base_patch16_224=768,
    vit_base_patch16_224_in21k=768,
    deit_small_patch16_224=384,
)

class IC_model(nn.Module):
    '''
        Image classification model for continual learning: a ViT/DeiT encoder
        (optionally MEMO-style: one general encoder plus task-specific extractors)
        followed by a CosineLinear classifier, together with the per-method batch
        losses (LwF, EWC, CLS-ER, BaCE, LUCIR, PODNet).

        Auxiliary tensors created inside the losses are placed on the device of the
        input tensors, so the losses run on CPU as well as on GPU.
    '''

    def __init__(self, output_dim, params):
        '''
            Args:
                output_dim: number of classes of the cosine classifier.
                params: configuration namespace; this constructor reads `backbone`,
                    `is_MEMO`, `use_adapter` (DeiT only) and `is_rand_init`.

            Raises:
                KeyError: if `params.backbone` is not a key of `hidden_dim_dict`.
                NotImplementedError: for backbones other than ViT / DeiT / resnet18.
        '''
        super().__init__()
        self.params = params
        self.hidden_dim = hidden_dim_dict[params.backbone]
        self.output_dim = output_dim
        if self._is_transformer_backbone(params.backbone):
            if self.params.is_MEMO:
                # MEMO: a shared "general" encoder; `encoder_specific` starts empty
                # (presumably one extractor is appended per task elsewhere).
                self.encoder_general = create_model(model_name=params.backbone+'_general', # 'vit_base_patch16_224', 'vit_tiny_patch16_224_in21k'
                                            pretrained=True,
                                            num_classes=1000,
                                            drop_rate=0.0,
                                            drop_path_rate=0.0,
                                            drop_block_rate=None)
                self.encoder_specific = nn.ModuleList()
            elif 'deit' in params.backbone:
                self.encoder = create_model(model_name=params.backbone,
                                            pretrained=False,
                                            num_classes=1000,
                                            drop_rate=0.0,
                                            drop_path_rate=0.0,
                                            drop_block_rate=None,
                                            use_adapter=self.params.use_adapter)
                # Loading DeiT model pre-trained with 611 classes: copy every
                # checkpoint weight whose name matches, except the classification head.
                # NOTE: torch.load unpickles the file; only use trusted checkpoints.
                checkpoint = torch.load('.deit_pretrained/best_checkpoint.pth', map_location='cpu')
                target = self.encoder.state_dict()
                pretrain = checkpoint['model']
                target.update({k: pretrain[k] for k in target if k in pretrain and 'head' not in k})
                self.encoder.load_state_dict(target)
            else:
                self.encoder = create_model(model_name=params.backbone, # 'vit_base_patch16_224', 'vit_tiny_patch16_224_in21k'
                                            pretrained=True,
                                            num_classes=1000,
                                            drop_rate=0.0,
                                            drop_path_rate=0.0,
                                            drop_block_rate=None)
        elif 'resnet18' in params.backbone:
            # NOTE(review): no encoder is built for resnet18 here; forward_encoder
            # calls self.encoder, so it must be assigned elsewhere.
            pass
        else:
            raise NotImplementedError()

        if self.params.is_rand_init and self._is_transformer_backbone(params.backbone):
            # Re-initialise the encoder with its own `_init_weights`
            # (presumably discarding the pretrained weights).
            self.encoder.apply(self.encoder._init_weights)
        self.classifier = CosineLinear(self.hidden_dim, self.output_dim)

    @staticmethod
    def _is_transformer_backbone(backbone):
        '''Return True for ViT / DeiT backbone names.'''
        return 'vit' in backbone or 'deit' in backbone

    def forward(self, X, return_feat=False, task_id=-1):
        '''
            Return the classifier logits for `X`; when `return_feat` is True return
            the tuple `(logits, features)` instead.
        '''
        features = self.forward_encoder(X, task_id=task_id)
        logits = self.forward_classifier(features)
        if return_feat:
            return logits, features
        return logits

    def forward_encoder(self, X, task_id=-1):
        '''
            Encode `X` into features ([CLS] `pre_logits` for transformer backbones).

            With BaCE prompt tuning and task_id > 0 (the task id is known during
            training), the first `task_id * BaCE_prompt_len` rows of
            `self.soft_prompt` are passed to the encoder as prompt features.
        '''
        if self._is_transformer_backbone(self.params.backbone):
            if self.params.is_MEMO:
                base_features = self.encoder_general(X)
                # Concatenate the [CLS] features of every task-specific extractor.
                features = torch.cat(
                    [extractor(base_features)['pre_logits'] for extractor in self.encoder_specific],
                    dim=-1,
                )
            elif self.params.is_BaCE and self.params.BaCE_prompt_tuning and task_id > 0:
                # Index on the prompt tensor's own device instead of assuming CUDA.
                dummy_ids = torch.arange(task_id*self.params.BaCE_prompt_len,
                                         device=self.soft_prompt.device)
                prompt_features = self.soft_prompt[dummy_ids]
                features = self.encoder.forward(X, task_id=task_id, prompt_features=prompt_features)['pre_logits'] # [CLS] feature
            else:
                features = self.encoder.forward(X)['pre_logits'] # [CLS] feature
        elif 'resnet18' in self.params.backbone:
            features = self.encoder(X, returnt='features')
        return features

    def forward_classifier(self, features):
        '''Apply the cosine classifier to encoder features.'''
        return self.classifier(features)

    def batch_loss_LwF(self, inputs, logits, labels, refer_model, task_id):
        '''
            Method: Learning without Forgetting

            Cross-entropy plus KL(reference || current) where the reference
            distribution is softmax(refer_logits / LWF_temperature) and the current
            log-probabilities are log_softmax(logits / LWF_temperature) over ALL
            current classes, sliced to the reference model's classes.

            Returns:
                (total_loss, ce_loss value, LWF_lambda * distillation value)
        '''
        # classification loss
        ce_loss = nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]),
                                        labels.flatten().long())

        # distillation loss
        refer_dim = refer_model.classifier.output_dim
        with torch.no_grad():
            refer_model.eval()
            refer_features = refer_model.forward_encoder(inputs)
            refer_logits = refer_model.forward_classifier(refer_features)
            ref_score = refer_logits.div(self.params.LWF_temperature).softmax(-1)

        old_class_score = F.log_softmax(logits/self.params.LWF_temperature, dim=-1)[:, :refer_dim]

        distill_loss = nn.KLDivLoss(reduction='batchmean')(old_class_score, ref_score)

        total_loss = ce_loss + self.params.LWF_lambda*distill_loss

        return total_loss, ce_loss.item(), self.params.LWF_lambda*distill_loss.item()

    def batch_loss_EWC(self, logits, labels, refer_model, fisher, task_id):
        '''
            Method: Elastic Weight Consolidation

            Cross-entropy plus sum_k fisher_k * (theta_old_k - theta_k)^2 / 2 over
            the encoder parameters. Parameters are matched with `refer_model` by
            name; encoder parameters the reference model does not have (i.e. new
            ones) are not regularized.

            Returns:
                (total_loss, ce_loss value, EWC_lambda * penalty value)
        '''
        # classification loss
        ce_loss = nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]),
                                        labels.flatten().long())

        # EWC penalty; kept as a tensor even if nothing is regularized so that
        # `.item()` below always works.
        old_params = dict(refer_model.named_parameters())
        distill_loss = torch.zeros((), device=logits.device)
        for name, param in self.named_parameters():
            # Only Regularize Encoder
            if 'encoder' not in name or name not in old_params:
                continue
            # The reference weights are a fixed anchor: never send gradients into them.
            param_old = old_params[name].detach()
            distill_loss = distill_loss + torch.sum(fisher[name]*(param_old-param).pow(2))/2

        total_loss = ce_loss + self.params.EWC_lambda*distill_loss

        return total_loss, ce_loss.item(), self.params.EWC_lambda*distill_loss.item()

    def batch_loss(self, logits, labels):
        '''
            Cross-Entropy Loss

            Returns:
                (total_loss, ce_loss value, 0)
        '''
        ce_loss = nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]),
                                        labels.flatten().long())
        return ce_loss, ce_loss.item(), 0

    def batch_loss_clser(self, inputs, logits, labels, refer_model, refer_model_2):
        '''
            Method: CLS-ER

            Cross-entropy on the whole batch plus an MSE consistency loss on the
            second half of the batch (named `buf_*`, presumably replay-buffer
            samples). For each of those samples the target logits come from
            whichever of `refer_model` (stable) / `refer_model_2` (plastic) gives
            the higher probability to the ground-truth class.

            Returns:
                (total_loss, ce_loss value, CLSER_lambda * consistency value)
        '''
        # (1) classification loss
        ce_loss = nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]),
                                        labels.flatten().long())

        half = inputs.shape[0]//2
        buf_input = inputs[half:]
        buf_labels = labels[half:]

        # (2) Distillation loss
        with torch.no_grad():
            refer_model.eval()
            stable_model_logits = refer_model.forward_classifier(refer_model.forward_encoder(buf_input))
            stable_model_prob = F.softmax(stable_model_logits, 1)
            refer_model_2.eval()
            plastic_model_logits = refer_model_2.forward_classifier(refer_model_2.forward_encoder(buf_input))
            plastic_model_prob = F.softmax(plastic_model_logits, 1)

            # Per sample: take the stable logits when the stable model is more
            # confident on the true class, otherwise the plastic logits.
            label_mask = F.one_hot(buf_labels.long(), num_classes=stable_model_logits.shape[-1]) > 0
            sel_idx = (stable_model_prob[label_mask] > plastic_model_prob[label_mask]).unsqueeze(1)
            ema_logits = torch.where(sel_idx, stable_model_logits, plastic_model_logits)

        distill_loss = nn.MSELoss()(logits[half:, :], ema_logits.detach())

        total_loss = ce_loss + self.params.CLSER_lambda*distill_loss

        return total_loss, ce_loss.item(), self.params.CLSER_lambda*distill_loss.item()

    def _bace_lambda(self, task_id):
        '''
            Weight of the new-sample KL distillation term, scheduled by
            `BaCE_lambda_type` ('constant', 'linear' in task_id, or 'sqrt' of task_id).

            Raises:
                NotImplementedError: for any other `BaCE_lambda_type`.
        '''
        lambda_type = self.params.BaCE_lambda_type
        if lambda_type == 'constant':
            return self.params.BaCE_lambda
        if lambda_type == 'linear':
            return self.params.BaCE_lambda*task_id
        if lambda_type == 'sqrt':
            return self.params.BaCE_lambda*np.sqrt(task_id)
        raise NotImplementedError(f'Unknown BaCE_lambda_type: {lambda_type}')

    def _bace_joint_score(self, anchor_score_all, old_mask, index, all_dims, task_id, num_task):
        '''
            Moving-average joint score used by BaCE.

            Old-class samples keep their own softmax score. A new-class sample gets
            cur_W0 * its own score + (1 - cur_W0) * a weighted sum of the cached
            scores of its neighbours closer than `BaCE_dist_threshold`. A neighbour
            with no cached score yet (first epoch) contributes a uniform distribution.

            Side effect: caches the detached CPU score of every new-class sample in
            `self.score_dict[idx]`.
        '''
        device = anchor_score_all.device
        # Own-score weight, scheduled linearly from BaCE_W0_bg (task 0) to BaCE_W0 (last task).
        cur_W0 = self.params.BaCE_W0_bg+(self.params.BaCE_W0-self.params.BaCE_W0_bg)*task_id/(num_task-1)
        uniform_score = torch.ones(all_dims, device=device).div(all_dims)
        joint_score_all = torch.zeros_like(anchor_score_all)
        for i, idx in enumerate(index):
            if isinstance(idx, torch.Tensor):
                idx = idx.item()
            if old_mask[i]:
                joint_score_all[i] = anchor_score_all[i]
                continue
            # Update cache
            self.score_dict[idx] = anchor_score_all[i].clone().detach().cpu()
            joint_score_all[i] += cur_W0*anchor_score_all[i]
            assert idx in self.knn_id_dict
            assert idx in self.knn_dist_dict
            selected = [
                (knn_idx, dist)
                for knn_idx, dist in zip(self.knn_id_dict[idx], self.knn_dist_dict[idx])
                if dist < self.params.BaCE_dist_threshold
            ]
            selected_dists = [dist for _, dist in selected]
            num_selected = len(selected)
            for knn_idx, dist in selected:
                if knn_idx in self.score_dict:
                    if self.params.BaCE_weight_type == 'average':
                        weight = 1/num_selected
                    elif self.params.BaCE_weight_type == 'distance':
                        # NOTE(review): weights grow with distance (farther neighbours
                        # count more) and divide by zero if all distances are 0 — confirm.
                        weight = dist/np.sum(selected_dists)
                    else:
                        raise NotImplementedError()
                    joint_score_all[i] += self.score_dict[knn_idx].to(device)*(1-cur_W0)*weight
                else: # first epoch: neighbour not cached yet
                    joint_score_all[i] += (1-cur_W0)*uniform_score/num_selected
        return joint_score_all

    def batch_loss_distill_BaCE(self, inputs, logits, labels, index, refer_model, task_id, num_task):
        '''
            Cross-Entropy Loss + Distillation loss(KLDivLoss) with Causal Effect Retrieving

            Args:
                inputs: batch passed to `refer_model` (encoded with task_id-1).
                logits: current model logits for the batch.
                labels: class labels; labels < classifier.fc0.output_dim are old classes.
                index: per-sample keys into score_dict / knn_id_dict / knn_dist_dict.
                refer_model: reference model (presumably the previous task's model).
                task_id, num_task: current task index and total number of tasks.

            Returns:
                (total_loss, ce_loss value, distillation value)

            Raises:
                NotImplementedError: unknown BaCE_lambda_type or BaCE_weight_type.
        '''
        # Validate the schedule before touching any cache.
        BaCE_lambda = self._bace_lambda(task_id)
        refer_dims = self.classifier.fc0.output_dim
        all_dims = self.classifier.output_dim

        new_mask = labels>=refer_dims # new classes
        old_mask = labels<refer_dims # old classes
        with torch.no_grad():
            refer_model.eval()
            refer_features = refer_model.forward_encoder(inputs, task_id=task_id-1)
            refer_logits = refer_model.forward_classifier(refer_features)

        # joint ce_loss
        anchor_score_all = F.softmax(logits, dim=-1)

        assert hasattr(self, 'score_dict')
        assert hasattr(self, 'knn_id_dict')
        assert hasattr(self, 'knn_dist_dict')
        if self.params.BaCE_W0 == 1:
            joint_score_all = anchor_score_all
        else:
            joint_score_all = self._bace_joint_score(anchor_score_all, old_mask, index,
                                                     all_dims, task_id, num_task)

        ce_loss = F.nll_loss(joint_score_all.add(1e-10).log(), labels)

        # Old-class samples: MSE to the reference logits on old classes.
        # New-class samples: KL to the reference distribution (temperature 2) on old classes.
        if old_mask.sum().item()>0:
            distill_loss_old = F.mse_loss(logits[old_mask, :refer_dims], refer_logits[old_mask, :refer_dims])
            if new_mask.sum().item()>0:
                distill_loss_new = nn.KLDivLoss(reduction='batchmean')(
                                        F.log_softmax(logits/2, dim=-1)[new_mask, :refer_dims],
                                        refer_logits[new_mask, :refer_dims].div(2).softmax(-1)
                                    )
                distill_loss = BaCE_lambda*distill_loss_new+distill_loss_old
            else:
                distill_loss = distill_loss_old
        else:
            distill_loss_new = nn.KLDivLoss(reduction='batchmean')(
                                    F.log_softmax(logits/2, dim=-1)[:, :refer_dims],
                                    refer_logits[:, :refer_dims].div(2).softmax(-1)
                                )
            distill_loss = BaCE_lambda*distill_loss_new

        total_loss = ce_loss + distill_loss

        return total_loss, ce_loss.item(), distill_loss.item()

    def batch_loss_lucir(self, inputs, features, logits, labels, refer_model):
        '''
            Cross-Entropy Loss + Distillation Loss(CosineEmbeddingLoss) + MarginRankingLoss

            Assumes classifier.output_dim > classifier.fc0.output_dim (the distill
            weight divides by the number of new classes).

            Returns:
                (total_loss, ce_loss value, distillation + margin-ranking value)
        '''
        refer_dims = self.classifier.fc0.output_dim
        all_dims = self.classifier.output_dim

        # (1) CE loss
        ce_loss = nn.CrossEntropyLoss()(logits, labels.flatten().long())

        # (2) distill loss: cosine alignment of current and reference features,
        # weighted by sqrt(#old classes / #new classes)
        lw_distill = self.params.lucir_lw_distill*math.sqrt(refer_dims/(all_dims-refer_dims))
        with torch.no_grad():
            refer_model.eval()
            refer_features = refer_model.forward_encoder(inputs)
        distill_loss = lw_distill * nn.CosineEmbeddingLoss()(
            features.view(-1, features.shape[-1]),
            refer_features.view(-1, features.shape[-1]),
            torch.ones(features.shape[0], device=features.device)
        )

        # (3) MR loss: the ground-truth score of each old-class sample should exceed
        # its top-K new-class scores by lucir_mr_dist.
        # Old classes + the "O" class; without replay only the "O" class, same as
        # distill_mask. (NOTE(review): wording looks carried over from an NER setup.)
        mr_mask = labels<refer_dims
        labels_masked = labels.flatten().long()[mr_mask].view(-1, 1)
        mr_logits = logits.view(-1, logits.shape[-1])[mr_mask]
        gt_scores = mr_logits.gather(1, labels_masked).repeat(1, self.params.lucir_K)
        max_novel_scores = mr_logits[:, refer_dims:].topk(self.params.lucir_K, dim=1)[0]

        count = gt_scores.size(0)
        if count > 0:
            mr_loss = nn.MarginRankingLoss(margin=self.params.lucir_mr_dist)(
                gt_scores.view(-1),
                max_novel_scores.view(-1),
                torch.ones(count*self.params.lucir_K, device=logits.device),
            ) * self.params.lucir_lw_mr
        else:
            mr_loss = torch.zeros((), device=logits.device, requires_grad=True)

        total_loss = ce_loss + distill_loss + mr_loss

        return total_loss, ce_loss.item(), distill_loss.item()+mr_loss.item()

    def batch_loss_podnet(self, inputs, features, logits, labels, refer_model):
        '''
            NCA Loss/Cross-Entropy Loss + Distillation Loss(CosineEmbeddingLoss+L2_norm)

            (1) On new-class samples: LSC-NCA loss (if `podnet_is_nca`) where the
                margin `podnet_nca_margin` is subtracted from the ground-truth
                similarity only, or plain cross-entropy otherwise.
            (2) On old-class samples: cosine-embedding "pod flat" distillation
                against the reference features, weighted by
                podnet_lw_pod_flat * sqrt(#old classes / #new classes).

            Returns:
                (total_loss, lsc_loss value, pod_flat_loss value)
        '''
        refer_dims = self.classifier.fc0.output_dim
        all_dims = self.classifier.output_dim

        # (1) NCA loss
        lsc_mask = labels >= refer_dims

        if not lsc_mask.any():
            lsc_loss = torch.zeros((), device=logits.device, requires_grad=True)
        elif self.params.podnet_is_nca:
            similarities = logits.view(-1, all_dims)[lsc_mask]
            targets = labels.flatten().long()[lsc_mask]
            rows = torch.arange(similarities.shape[0])
            margins = torch.zeros_like(similarities)
            margins[rows, targets] = self.params.podnet_nca_margin
            # Margin on the target class only: a margin subtracted from every class
            # would be cancelled by the max-subtraction below.
            similarities = self.params.podnet_nca_scale * (similarities - margins)

            similarities = similarities - similarities.max(1)[0].view(-1, 1)  # Stability

            # Zero out the positive entry so the denominator excludes its similarity.
            disable_pos = torch.zeros_like(similarities)
            disable_pos[rows, targets] = similarities[rows, targets]

            numerator = similarities[rows, targets]
            denominator = similarities - disable_pos

            losses = numerator - torch.log(torch.exp(denominator).sum(-1))

            lsc_loss = torch.mean(-losses)
        else:
            lsc_loss = nn.CrossEntropyLoss()(logits.view(-1, all_dims)[lsc_mask],
                                             labels.flatten().long()[lsc_mask])

        # (2) distill loss
        distill_mask = labels<refer_dims

        if not distill_mask.any():
            pod_flat_loss = torch.zeros((), device=logits.device, requires_grad=True)
        else:
            with torch.no_grad():
                refer_model.eval()
                refer_features = refer_model.forward_encoder(inputs)

            lw_pod_flat = self.params.podnet_lw_pod_flat * math.sqrt(refer_dims/(all_dims-refer_dims))
            pod_flat_loss = lw_pod_flat * nn.CosineEmbeddingLoss(reduction='mean')(
                                features[distill_mask].view(-1, features.shape[-1]),
                                refer_features[distill_mask].view(-1, features.shape[-1]),
                                torch.ones(distill_mask.nonzero().size(0), device=features.device))

        total_loss = lsc_loss + pod_flat_loss

        return total_loss, lsc_loss.item(), pod_flat_loss.item()